// Copyright 2015/2016 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.

package main

import (
	"bytes"
	"flag"
	"fmt"
	"go/format"
	"os"
	"path/filepath"
	"reflect"
	"sort"
	"strings"
	"sync"
	"text/template"

	"github.com/google/syzkaller/pkg/ast"
	"github.com/google/syzkaller/pkg/compiler"
	"github.com/google/syzkaller/pkg/hash"
	"github.com/google/syzkaller/pkg/osutil"
	"github.com/google/syzkaller/pkg/tool"
	"github.com/google/syzkaller/prog"
	"github.com/google/syzkaller/sys/generated"
	"github.com/google/syzkaller/sys/targets"
)

// Input types for registerTempl, defsTempl and syscallsTempl. The templates
// refer to these fields by name, so renaming a field requires updating the
// templates too. Define is also built with a positional literal, so its field
// order matters.
type (
	// SyscallData is one entry of the executor's syscalls[] table.
	SyscallData struct {
		// Name is the description-level call name (prog.Syscall.Name).
		Name string
		// CallName is the function emitted as the table's function pointer
		// when NeedCall is set (possibly a trampoline replacement).
		CallName string
		// NR is the syscall number, truncated to int32.
		NR int32
		// NeedCall reports whether the entry carries an explicit function pointer.
		NeedCall bool
		// Attrs holds the numeric call attributes with trailing zeros trimmed.
		Attrs []uint64
	}

	// Define is a `#define Name Value` fallback emitted (under #ifndef) in defs.h.
	Define struct {
		Name  string
		Value string
	}

	// ArchData carries everything the executor headers need for one OS/arch.
	ArchData struct {
		// Revision is the hash of the serialized descriptions for this arch.
		Revision string
		// ForkServer is 1 if the target's executor uses the fork server, else 0.
		ForkServer int
		GOARCH     string
		PageSize   uint64
		NumPages   uint64
		DataOffset uint64
		// Calls is sorted by Name.
		Calls []SyscallData
		// Defines is sorted by Name.
		Defines []Define
	}

	// OSData groups the per-arch data of a single OS.
	OSData struct {
		GOOS  string
		Archs []ArchData
	}

	// CallPropDescription is one field of the C call_props_t struct.
	CallPropDescription struct {
		// Type is the C type, taken from the reflect.Kind name of the Go field.
		Type string
		// Name is the C field name.
		Name string
	}

	// TemplateData is the root object passed to every template.
	TemplateData struct {
		Notice    string
		OSes      []OSData
		CallAttrs []string
		CallProps []CallPropDescription
	}
)

// Command-line flags. Both default to the empty string, in which case the
// paths built with filepath.Join are relative to the current directory.
var (
	// srcDir is the syzkaller checkout whose sys/<OS>/*.txt and *.const files are read.
	srcDir = flag.String("src", "", "path to root of syzkaller source dir")
	// outDir is where sys/ and executor/ outputs are written.
	outDir = flag.String("out", "", "path to out dir")
)

// main regenerates the description artifacts for every OS/arch in targets.List:
//   - a serialized description file per target under <out>/sys (see processJob),
//   - <out>/sys/register.go, which registers every target with sys/generated,
//   - <out>/executor/defs.h and <out>/executor/syscalls.h for the executor.
//
// It exits with a non-zero status if parsing or compilation fails for any
// target, or if some description is unsupported on every arch of an OS.
func main() {
	defer tool.Init()()

	// Clean up old files in case the set of architectures has changed.
	allFiles, err := filepath.Glob(filepath.Join(*outDir, "sys", generated.Glob()))
	if err != nil {
		tool.Failf("failed to glob: %v", err)
	}
	for _, file := range allFiles {
		os.Remove(file) // best-effort cleanup; errors are deliberately ignored
	}

	// Also remove old generated files since they will break build.
	// TODO: remove this after some time after 2025-01-23.
	oldFiles, err := filepath.Glob(filepath.Join(*outDir, "sys", "*", "gen", "*"))
	if err != nil {
		tool.Failf("failed to glob: %v", err)
	}
	for _, file := range oldFiles {
		os.Remove(file)
	}

	var OSList []string
	for OS := range targets.List {
		OSList = append(OSList, OS)
	}
	sort.Strings(OSList)

	data := &TemplateData{
		Notice: "Automatically generated by syz-sysgen; DO NOT EDIT.",
	}
	for _, OS := range OSList {
		descriptions := ast.ParseGlob(filepath.Join(*srcDir, "sys", OS, "*.txt"), nil)
		if descriptions == nil {
			os.Exit(1)
		}
		constFile := compiler.DeserializeConstFile(filepath.Join(*srcDir, "sys", OS, "*.const"), nil)
		if constFile == nil {
			os.Exit(1)
		}

		var archs []string
		for arch := range targets.List[OS] {
			archs = append(archs, arch)
		}
		sort.Strings(archs)

		var jobs []*Job
		for _, arch := range archs {
			target := targets.List[OS][arch]
			constInfo := compiler.ExtractConsts(descriptions, target, nil)
			if OS == targets.TestOS {
				// The ConstFile object provides no guarantees re concurrent read-write,
				// so let's patch it before we start goroutines.
				compiler.FabricateSyscallConsts(target, constInfo, constFile)
			}
			jobs = append(jobs, &Job{
				Target:      target,
				Unsupported: make(map[string]bool),
				ConstInfo:   constInfo,
			})
		}
		sort.Slice(jobs, func(i, j int) bool {
			return jobs[i].Target.Arch < jobs[j].Target.Arch
		})
		var wg sync.WaitGroup
		wg.Add(len(jobs))

		for _, job := range jobs {
			// Pass job explicitly so every goroutine gets its own job regardless
			// of the module's Go version (per-iteration loop variables need 1.22+).
			go func(job *Job) {
				defer wg.Done()
				processJob(job, descriptions, constFile)
			}(job)
		}
		wg.Wait()

		var syscallArchs []ArchData
		unsupported := make(map[string]int)
		for _, job := range jobs {
			if !job.OK {
				fmt.Printf("compilation of %v/%v target failed:\n", job.Target.OS, job.Target.Arch)
				for _, msg := range job.Errors {
					fmt.Print(msg)
				}
				os.Exit(1)
			}
			syscallArchs = append(syscallArchs, job.ArchData)
			for u := range job.Unsupported {
				unsupported[u]++
			}
		}
		data.OSes = append(data.OSes, OSData{
			GOOS:  OS,
			Archs: syscallArchs,
		})

		// Check names in sorted order so that the reported name is
		// deterministic when several are unsupported everywhere.
		unsupportedNames := make([]string, 0, len(unsupported))
		for what := range unsupported {
			unsupportedNames = append(unsupportedNames, what)
		}
		sort.Strings(unsupportedNames)
		for _, what := range unsupportedNames {
			if unsupported[what] == len(jobs) {
				tool.Failf("%v is unsupported on all arches (typo?)", what)
			}
		}
	}

	// One call_attrs_t field per SyscallAttrs field, in declaration order.
	// NOTE(review): generateExecutorSyscalls skips string-typed attributes when
	// emitting the syscalls[] initializers, while this list includes them; the
	// two only line up while string attributes are the trailing fields — confirm
	// before adding a non-trailing string attribute.
	attrs := reflect.TypeOf(prog.SyscallAttrs{})
	for i := 0; i < attrs.NumField(); i++ {
		data.CallAttrs = append(data.CallAttrs, prog.CppName(attrs.Field(i).Name))
	}

	props := prog.CallProps{}
	props.ForeachProp(func(name, _ string, value reflect.Value) {
		data.CallProps = append(data.CallProps, CallPropDescription{
			Type: value.Kind().String(),
			Name: prog.CppName(name),
		})
	})

	sort.Slice(data.OSes, func(i, j int) bool {
		return data.OSes[i].GOOS < data.OSes[j].GOOS
	})

	writeTemplate(filepath.Join(*outDir, "sys", "register.go"), registerTempl, data)
	writeTemplate(filepath.Join(*outDir, "executor", "defs.h"), defsTempl, data)
	writeTemplate(filepath.Join(*outDir, "executor", "syscalls.h"), syscallsTempl, data)
}

// Job holds the inputs and results of compiling the descriptions for a single
// target. main fills Target, Unsupported (an empty map) and ConstInfo; processJob
// fills in OK, Errors, Unsupported entries and ArchData. main runs one
// processJob per Job in its own goroutine, so each Job is written by one
// goroutine only.
type Job struct {
	// Target is the OS/arch being compiled.
	Target      *targets.Target
	// OK is set by processJob when no actionable errors were recorded.
	OK          bool
	// Errors holds formatted "pos: msg\n" diagnostics.
	Errors      []string
	// Unsupported records descriptions the compiler reported as unsupported
	// for this target; main fails if one is unsupported on every arch.
	Unsupported map[string]bool
	// ArchData is the executor header data produced by generateExecutorSyscalls.
	ArchData    ArchData
	// ConstInfo is the result of compiler.ExtractConsts for Target.
	ConstInfo   map[string]*compiler.ConstInfo
	// NOTE(review): Revision is neither set nor read in this file (the
	// revision hash is carried by ArchData.Revision) — confirm whether it
	// is still needed.
	Revision    string
}

// processJob compiles descriptions for job.Target and records the outcome in job.
//
// If compilation produces no program, the compiler diagnostics are left in
// job.Errors and job.OK stays false. Otherwise it does the following:
//   - writes the serialized description for the target under *outDir,
//   - fills job.Unsupported and job.ArchData,
//   - discards compilation warnings (syz-check prints them),
//   - runs the undefined-const check (skipped on Fuchsia),
//   - sets job.OK if that check reported nothing.
func processJob(job *Job, descriptions *ast.Description, constFile *compiler.ConstFile) {
	// Collect every flags declaration (name plus value identifiers) so it is
	// stored alongside the compiled description.
	var flagDescs []prog.FlagDesc
	for _, node := range descriptions.Nodes {
		intFlags, isFlags := node.(*ast.IntFlags)
		if !isFlags {
			continue
		}
		fd := prog.FlagDesc{Name: intFlags.Name.Name}
		for _, v := range intFlags.Values {
			fd.Values = append(fd.Values, v.Ident)
		}
		flagDescs = append(flagDescs, fd)
	}

	reportError := func(pos ast.Pos, msg string) {
		job.Errors = append(job.Errors, fmt.Sprintf("%v: %v\n", pos, msg))
	}

	// Const values for this arch, sorted by name for deterministic output.
	archConsts := constFile.Arch(job.Target.Arch)
	sortedConsts := make([]prog.ConstValue, 0, len(archConsts))
	for name, value := range archConsts {
		sortedConsts = append(sortedConsts, prog.ConstValue{Name: name, Value: value})
	}
	sort.Slice(sortedConsts, func(a, b int) bool {
		return sortedConsts[a].Name < sortedConsts[b].Name
	})

	compiled := compiler.Compile(descriptions, archConsts, job.Target, reportError)
	if compiled == nil {
		return
	}
	for what := range compiled.Unsupported {
		job.Unsupported[what] = true
	}

	serialized, err := generated.Serialize(&generated.Desc{
		Syscalls:  compiled.Syscalls,
		Resources: compiled.Resources,
		Types:     compiled.Types,
		Consts:    sortedConsts,
		Flags:     flagDescs,
	})
	if err != nil {
		tool.Fail(err)
	}
	outFile := filepath.Join(*outDir, "sys", generated.FileName(job.Target.OS, job.Target.Arch))
	writeFile(outFile, serialized)

	job.ArchData = generateExecutorSyscalls(job.Target, compiled.Syscalls, hash.String(serialized))

	// Compilation warnings are not reported here; syz-check prints them.
	job.Errors = nil
	// Undefined consts are always actionable, except on Fuchsia, which has
	// too many broken consts.
	if job.Target.OS != targets.Fuchsia {
		constsAreAllDefined(constFile, job.ConstInfo, reportError)
	}
	job.OK = len(job.Errors) == 0
}

// generateExecutorSyscalls builds the per-arch data consumed by defsTempl and
// syscallsTempl for target from its compiled syscalls. rev is stored as
// ArchData.Revision.
//
// Each syscall's attributes are read from prog.SyscallAttrs in field order:
// bool fields become 0/1, uint64 fields are used as-is, and string fields are
// skipped. Trailing zero values are trimmed. Any other field kind panics.
// For calls the target has a number for and that need a fallback definition,
// a "<SyscallPrefix><CallName> <NR>" define is recorded. Calls and Defines are
// sorted by name so the output is deterministic.
func generateExecutorSyscalls(target *targets.Target, syscalls []*prog.Syscall, rev string) ArchData {
	data := ArchData{
		Revision:   rev,
		GOARCH:     target.Arch,
		PageSize:   target.PageSize,
		NumPages:   target.NumPages,
		DataOffset: target.DataOffset,
	}
	if target.ExecutorUsesForkServer {
		data.ForkServer = 1
	}
	defines := make(map[string]string)
	for _, c := range syscalls {
		var attrVals []uint64
		attrs := reflect.ValueOf(c.Attrs)
		// Position in attrVals of the last non-zero value, or -1 if all are zero.
		last := -1
		for i := 0; i < attrs.NumField(); i++ {
			attr := attrs.Field(i)
			val := uint64(0)
			switch attr.Type().Kind() {
			case reflect.Bool:
				if attr.Bool() {
					val = 1
				}
			case reflect.Uint64:
				val = attr.Uint()
			case reflect.String:
				continue
			default:
				panic("unsupported syscall attribute type")
			}
			attrVals = append(attrVals, val)
			if val != 0 {
				// Use the position in attrVals, not the field index i.
				// Skipped string fields make the two diverge, and i would
				// then slice past the collected values.
				last = len(attrVals) - 1
			}
		}
		data.Calls = append(data.Calls, newSyscallData(target, c, attrVals[:last+1]))
		// Some syscalls might not be present on the compiling machine, so we
		// generate definitions for them.
		if target.HasCallNumber(c.CallName) && target.NeedSyscallDefine(c.NR) {
			defines[target.SyscallPrefix+c.CallName] = fmt.Sprintf("%d", c.NR)
		}
	}
	sort.Slice(data.Calls, func(i, j int) bool {
		return data.Calls[i].Name < data.Calls[j].Name
	})
	// Get a sorted list of definitions.
	defineNames := make([]string, 0, len(defines))
	for key := range defines {
		defineNames = append(defineNames, key)
	}
	sort.Strings(defineNames)
	for _, key := range defineNames {
		data.Defines = append(data.Defines, Define{key, defines[key]})
	}
	return data
}

// newSyscallData converts sc into the syscalls[] table record for target.
//
// CallName is sc.CallName unless target.SyscallTrampolines maps sc.Name to a
// replacement. NeedCall is set when the target has no syscall number for
// sc.CallName or a trampoline is in use. It is never set for syz_builtin* calls.
// NR is truncated to int32.
func newSyscallData(target *targets.Target, sc *prog.Syscall, attrs []uint64) SyscallData {
	callName := sc.CallName
	trampoline, hasTrampoline := target.SyscallTrampolines[sc.Name]
	if hasTrampoline {
		callName = trampoline
	}

	needCall := !target.HasCallNumber(sc.CallName) || hasTrampoline
	if strings.HasPrefix(sc.Name, "syz_builtin") {
		// These are declared in the compiler for internal purposes.
		needCall = false
	}

	return SyscallData{
		Name:     sc.Name,
		CallName: callName,
		NR:       int32(sc.NR),
		NeedCall: needCall,
		Attrs:    attrs,
	}
}

// writeTemplate executes templ with data and writes the output to file via
// writeFile. Outputs whose name ends in ".go" are gofmt-formatted first.
// Template or formatting failures are reported through tool.Failf.
func writeTemplate(file string, templ *template.Template, data any) {
	var rendered bytes.Buffer
	if err := templ.Execute(&rendered, data); err != nil {
		tool.Failf("failed to execute template: %v", err)
	}
	if !strings.HasSuffix(file, ".go") {
		writeFile(file, rendered.Bytes())
		return
	}
	formatted, err := format.Source(rendered.Bytes())
	if err != nil {
		tool.Failf("failed to format generated source: %v", err)
	}
	writeFile(file, formatted)
}

// writeFile writes data to file, creating the parent directory if needed.
// If file already contains exactly data, it is left untouched, so unchanged
// outputs are not rewritten. Failures are reported through tool.Failf.
func writeFile(file string, data []byte) {
	if current, err := os.ReadFile(file); err == nil && bytes.Equal(data, current) {
		return
	}
	// Report directory creation failures directly instead of letting them
	// surface as a less specific write error below.
	if err := osutil.MkdirAll(filepath.Dir(file)); err != nil {
		tool.Failf("failed to create output dir: %v", err)
	}
	if err := osutil.WriteFile(file, data); err != nil {
		tool.Failf("failed to write output file: %v", err)
	}
}

// registerTempl renders sys/register.go (written by main). The generated file
// imports every OS package and, from init, calls generated.Register once per
// OS/arch. Each call passes the arch revision, the OS package's InitTarget and
// the embedded gen/*.gob.flate files. writeTemplate gofmt-formats the result,
// so the loose whitespace in the import block is normalized.
// nolint: lll
var registerTempl = template.Must(template.New("register").Parse(`// {{.Notice}}

package sys

import (
	"embed"

	"github.com/google/syzkaller/sys/generated"
	{{range $os := $.OSes}}
	"github.com/google/syzkaller/sys/{{$os.GOOS}}"{{end}}
)

//go:embed gen/*.gob.flate
var files embed.FS

func init() {
	{{range $os := $.OSes}}{{range $arch := $os.Archs}}generated.Register("{{$os.GOOS}}", "{{$arch.GOARCH}}", "{{$arch.Revision}}", {{$os.GOOS}}.InitTarget, files)
	{{end}}{{end}}
}
`))

// defsTempl renders executor/defs.h (written by main). It declares:
//   - call_attrs_t, with one uint64_t per entry of TemplateData.CallAttrs,
//   - call_props_t,
//   - a read_call_props_t(var, reader) macro that assigns each prop field from
//     one evaluation of reader (presumably reader consumes the next value —
//     confirm at the executor call site).
//
// Then, for each GOOS_*/GOARCH_* combination, it defines GOOS, GOARCH,
// SYZ_REVISION, SYZ_EXECUTOR_USES_FORK_SERVER, SYZ_PAGE_SIZE, SYZ_NUM_PAGES
// and SYZ_DATA_OFFSET, plus #ifndef-guarded fallbacks from ArchData.Defines.
// Inside {{range $arch}} the bare {{.X}} references resolve against the
// current ArchData.
var defsTempl = template.Must(template.New("defs").Parse(`// {{.Notice}}

struct call_attrs_t { {{range $attr := $.CallAttrs}}
	uint64_t {{$attr}};{{end}}
};

struct call_props_t { {{range $attr := $.CallProps}}
	{{$attr.Type}} {{$attr.Name}};{{end}}
};

#define read_call_props_t(var, reader) { \{{range $attr := $.CallProps}}
	(var).{{$attr.Name}} = ({{$attr.Type}})(reader); \{{end}}
}

{{range $os := $.OSes}}
#if GOOS_{{$os.GOOS}}
#define GOOS "{{$os.GOOS}}"
{{range $arch := $os.Archs}}
#if GOARCH_{{$arch.GOARCH}}
#define GOARCH "{{.GOARCH}}"
#define SYZ_REVISION "{{.Revision}}"
#define SYZ_EXECUTOR_USES_FORK_SERVER {{.ForkServer}}
#define SYZ_PAGE_SIZE {{.PageSize}}
#define SYZ_NUM_PAGES {{.NumPages}}
#define SYZ_DATA_OFFSET {{.DataOffset}}
{{range $c := $arch.Defines}}#ifndef {{$c.Name}}
#define {{$c.Name}} {{$c.Value}}
#endif
{{end}}#endif
{{end}}
#endif
{{end}}
`))

// syscallsTempl renders executor/syscalls.h (written by main). For each
// GOOS_*/GOARCH_* combination it emits a syscalls[] table of call_t entries,
// in ArchData.Calls order (sorted by name). Each entry has the form
// {name, NR[, {attrs...}][, (syscall_t)CallName]}. The attrs initializer is
// emitted whenever NeedCall is set, even if it is empty. Presumably this keeps
// the function pointer in the call_t field that follows the attributes
// (call_t is declared outside this file — confirm).
// nolint: lll
var syscallsTempl = template.Must(template.New("syscalls").Parse(`// {{.Notice}}
// clang-format off
{{range $os := $.OSes}}
#if GOOS_{{$os.GOOS}}
{{range $arch := $os.Archs}}
#if GOARCH_{{$arch.GOARCH}}
const call_t syscalls[] = {
{{range $c := $arch.Calls}}    {"{{$c.Name}}", {{$c.NR}}{{if or $c.Attrs $c.NeedCall}}, { {{- range $attr := $c.Attrs}}{{$attr}}, {{end}}}{{end}}{{if $c.NeedCall}}, (syscall_t){{$c.CallName}}{{end}}},
{{end}}};
#endif
{{end}}
#endif
{{end}}
`))
